# Setup ----
# Load the project configuration, the analytic sample and the survey
# weighting helpers. here::here() is namespaced so that the very first line
# does not depend on the `here` package already being attached.
source(here::here("analysis", "visibility", "analysis_config.R"))

g <- readRDS(here::here("data", "output", "visibility_analysis.rds"))
library(survey)
source(here::here("analysis", "visibility", "processing", "survey_weights.R"))

message("Diagnose weights for complete analytic sample.")
# Design with no clustering (ids = ~1), weighted by `wt_all_trim`
# (presumably the trimmed final weights -- confirm in survey_weights.R).
s.analytic <- svydesign(ids = ~1, weights = ~wt_all_trim, data = g)
# ACS benchmark targets; setting sample_size to 1 returns proportions, so
# they are directly comparable with the survey shares computed below.
acs_base <- prepare_acs_targets(sample_size = 1)

# Create comprehensive comparison table ----

# Build the comparison rows for one categorical margin.
#
# var:     column name holding the categories, both in `data` and `targets`.
# label:   value written to the `variable` column of the result.
# targets: ACS benchmark data frame with a `Freq` column and a `var` column.
# design:  survey design used for the weighted shares.
# data:    data used for the unweighted shares.
#
# Returns a data.frame with one row per category of the weighted table and
# columns variable, category, unweighted_survey, weighted_survey and
# acs_population. acs_population is NA for categories without a target.
#
# Each survey table is normalized by its own total (prop.table) so the
# weighted and unweighted shares are on the same base (the tabulated,
# non-missing categories). Dividing by sum(weights(design)) instead would
# make the weighted shares sum to less than 1 whenever `var` has NAs.
compare_margin <- function(var, label, targets, design, data) {
  weighted <- prop.table(svytable(reformulate(var), design))
  unweighted <- prop.table(table(data[[var]]))
  acs <- setNames(targets$Freq, targets[[var]])
  categories <- names(weighted)

  data.frame(
    variable = rep(label, length(categories)),
    category = categories,
    unweighted_survey = as.numeric(unweighted[categories]),
    weighted_survey = as.numeric(weighted),
    acs_population = as.numeric(acs[categories]),
    stringsAsFactors = FALSE
  )
}

# 1. Race comparison
race_comparison <- compare_margin(
  "race_acs", "Race", acs_base$race, s.analytic, g
)

# 2. Income comparison
income_comparison <- compare_margin(
  "income_acs", "Income", acs_base$income, s.analytic, g
)

# 3. Region comparison
region_comparison <- compare_margin(
  "region", "Region", acs_base$region, s.analytic, g
)

# 4. Age x Education x Gender joint distribution ----
# Both tables are built from the same formula with xtabs()-style complete-case
# tabulation, so they share level sets. Each is normalized by its own total,
# which keeps the weighted and unweighted shares on the same base.
joint_formula <- ~ age_acs + edu_acs + gender
joint_weighted_table <- prop.table(svytable(joint_formula, s.analytic))
joint_unweighted_table <- prop.table(xtabs(joint_formula, data = g))

# Join the columns of a data frame row-wise into " | "-separated labels.
# paste() is used instead of apply(), which would coerce through as.matrix()
# and format() (possibly padding numeric columns).
paste_cells <- function(df) {
  do.call(paste, c(unname(as.list(df)), sep = " | "))
}

# Flatten an N-way table into a named numeric vector. Names are the cell
# labels joined with " | ". expand.grid() varies the first dimension fastest,
# matching the element order of as.vector().
flatten_table <- function(tab) {
  cells <- expand.grid(dimnames(tab), stringsAsFactors = FALSE)
  setNames(as.vector(tab), paste_cells(cells))
}

joint_weighted_vec <- flatten_table(joint_weighted_table)
joint_unweighted_vec <- flatten_table(joint_unweighted_table)
joint_names <- names(joint_weighted_vec)

# ACS joint targets keyed by the same " | "-joined labels
acs_joint_vec <- setNames(
  acs_base$joint$Freq,
  paste_cells(acs_base$joint[, c("age_acs", "edu_acs", "gender")])
)

# All sources are aligned by cell name (not position); cells missing from a
# source come out as NA.
joint_comparison <- data.frame(
  variable = "Age x Education x Gender",
  category = joint_names,
  unweighted_survey = as.numeric(joint_unweighted_vec[joint_names]),
  weighted_survey = as.numeric(joint_weighted_vec),
  acs_population = as.numeric(acs_joint_vec[joint_names]),
  stringsAsFactors = FALSE
)

# Combine all comparisons into a single long table ----
all_comparisons <- rbind(
  race_comparison, income_comparison, region_comparison, joint_comparison
)

# Absolute gap between each survey share and its ACS benchmark. These are NA
# wherever acs_population is NA (no matching ACS target).
all_comparisons$abs_diff_weighted_vs_acs <-
  abs(all_comparisons$weighted_survey - all_comparisons$acs_population)
all_comparisons$abs_diff_unweighted_vs_acs <-
  abs(all_comparisons$unweighted_survey - all_comparisons$acs_population)

# Display the comprehensive table
cat("=== COMPREHENSIVE COMPARISON TABLE ===\n")
print(all_comparisons)

# Print one "<label> - Mean: m Max: M" line for a vector of gaps, ignoring
# NA gaps (categories without an ACS target).
report_gap <- function(label, gaps) {
  cat(label, "- Mean:", round(mean(gaps, na.rm = TRUE), 4),
      "Max:", round(max(gaps, na.rm = TRUE), 4), "\n")
}

# Summary statistics
cat("\nSummary of absolute differences:\n")
report_gap("Weighted vs ACS", all_comparisons$abs_diff_weighted_vs_acs)
report_gap("Unweighted vs ACS", all_comparisons$abs_diff_unweighted_vs_acs)

# Create LaTeX table output ----
library(xtable)

# Row indices at which a new variable type (Race, Income, Region, Joint)
# starts in `all_comparisons`. This assumes rows of the same variable are
# contiguous, as produced by the rbind() of the per-variable tables.
group_breaks <- which(
  diff(match(all_comparisons$variable, unique(all_comparisons$variable))) != 0
) + 1

# Compact a joint-distribution label such as
# "18 to 24 years | 0 | <gender>" into "18-24 × No College × <gender>".
# Literal substitutions are applied in order. The " × 0 × " / " × 1 × " rules
# depend on the separator rule having run first.
shorten_joint_label <- function(label) {
  replacements <- c(
    " | " = " × ",
    "18 to 24 years" = "18-24",
    "25 to 34 years" = "25-34",
    "35 to 44 years" = "35-44",
    "45 to 64 years" = "45-64",
    "65 years and over" = "65+",
    " × 0 × " = " × No College × ",
    " × 1 × " = " × College × "
  )
  for (pattern in names(replacements)) {
    label <- gsub(pattern, replacements[[pattern]], label, fixed = TRUE)
  }
  label
}

# Formatted version of the table for LaTeX: numeric columns rounded to two
# decimals, variable and category merged into one label column to save space.
latex_table <- all_comparisons %>%
  mutate(
    across(
      c(unweighted_survey, weighted_survey, acs_population,
        abs_diff_weighted_vs_acs, abs_diff_unweighted_vs_acs),
      ~ round(.x, 2)
    ),
    variable_category = if_else(
      variable == "Age x Education x Gender",
      shorten_joint_label(category),
      paste0(variable, ": ", category)
    )
  ) %>%
  select(variable_category, unweighted_survey, weighted_survey, acs_population,
         abs_diff_weighted_vs_acs, abs_diff_unweighted_vs_acs)

# Generate LaTeX tables ----

# Print an xtable as LaTeX with the settings shared by both tables. Extra
# print.xtable() options are forwarded via `...`. The table is printed to the
# console, and the LaTeX string is returned invisibly so it can be written to
# a file without re-rendering.
print_latex <- function(x, ...) {
  print(x,
        type = "latex",
        include.rownames = FALSE,
        caption.placement = "top",
        table.placement = "htbp",
        floating = TRUE,
        ...)
}

# Detailed comparison table
xt <- xtable(latex_table,
             caption = "Comparison of Survey Distributions with ACS Population Benchmarks",
             label = "tab:weight_diagnosis",
             digits = 2)

# Column headers contain raw LaTeX (\shortstack) for two-line labels
names(xt) <- c("Demographic Category", "Unweighted", "Weighted", "ACS Target", 
               "\\shortstack{Abs Diff\\\\(W-ACS)}", "\\shortstack{Abs Diff\\\\(U-ACS)}")

# Only the headers are passed through verbatim. Body labels come from the data
# and are escaped by xtable's default sanitizer, so characters such as $, %, &
# or _ in category names cannot break the LaTeX build.
detailed_tex <- print_latex(
  xt,
  size = "\\footnotesize",
  sanitize.colnames.function = function(x) x,
  hline.after = c(-1, 0, group_breaks - 1, nrow(latex_table))
)

# Alternative: a more compact table summarizing the fit by variable type
summary_table <- all_comparisons %>%
  group_by(variable) %>%
  summarise(
    n_categories = n(),
    mean_abs_diff_weighted = round(mean(abs_diff_weighted_vs_acs, na.rm = TRUE), 4),
    max_abs_diff_weighted = round(max(abs_diff_weighted_vs_acs, na.rm = TRUE), 4),
    mean_abs_diff_unweighted = round(mean(abs_diff_unweighted_vs_acs, na.rm = TRUE), 4),
    max_abs_diff_unweighted = round(max(abs_diff_unweighted_vs_acs, na.rm = TRUE), 4),
    .groups = "drop"
  )

xt_summary <- xtable(summary_table,
                     caption = "Summary of Weighting Performance by Variable Type",
                     label = "tab:weight_summary",
                     digits = 4)

names(xt_summary) <- c("Variable", "N Categories", "Mean |W-ACS|", "Max |W-ACS|", 
                       "Mean |U-ACS|", "Max |U-ACS|")

cat("\n\n=== SUMMARY TABLE ===\n")
# Nothing in this table needs raw LaTeX, so the default sanitizer is kept.
# It escapes the literal "|" in the headers, which would otherwise render
# incorrectly (e.g. as an em dash under the OT1 font encoding).
summary_tex <- print_latex(xt_summary)

# Save tables to the pnas/tables directory
output_dir <- here::here("output", "pnas", "tables")
dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

detailed_file <- file.path(output_dir, "tab_S3_weights.tex")
summary_file <- file.path(output_dir, "weights_diagnosis_summary.tex")

cat(detailed_tex, file = detailed_file)

# Diagnostic table (not in paper); `pnas` is expected from analysis_config.R
if (isFALSE(pnas)) {
  cat(summary_tex, file = summary_file)
}

message("LaTeX tables saved to:")
message("  - ", detailed_file)
if (isFALSE(pnas)) {
  message("  - ", summary_file)
}